Scaled Bias Add support after CUBLAS GGEMM#2885
vthumbe1503 wants to merge 6 commits into NVIDIA:main from …ed_linear_integration

Conversation
Greptile Summary

This PR adds optional per-row scale support to the grouped bias add that runs after the cuBLAS grouped GEMM.

Confidence Score: 5/5. Safe to merge; all remaining findings are P2 style/improvement suggestions with no runtime risk for current callers. The core scaled-bias logic is correctly implemented: the fmaf argument order matches documented semantics, the shared-memory cumsum is correctly initialized and synchronized, and the empty-tensor sentinel correctly disables scaling. The two P2 findings (dead pre-loop bias load; missing tensor_offsets guard) do not affect current callers, since bias GroupedTensors are always packed in the Python bindings.

Important Files Changed: transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — dead pre-loop load and missing tensor_offsets guard in nvte_grouped_bias_add.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as Python (gemm.py)
    participant Ext as C++ Extension (gemm.cpp)
    participant NVTE as nvte_grouped_gemm
    participant BiasKernel as nvte_grouped_bias_add
    Py->>Ext: general_grouped_gemm_for_grouped_tensor(A, B, out, bias, bias_scale)
    Ext->>Ext: prepare_grouped_gemm_config(alpha, beta, ...)
    Ext->>NVTE: nvte_grouped_gemm(A, B, C=D, D, alpha, beta, ...)
    NVTE-->>Ext: D = alpha * A @ B + beta * C
    alt bias is not None
        Ext->>BiasKernel: nvte_grouped_bias_add(D, bias, scale)
        Note over BiasKernel: Build shared cumsum for row-to-tensor map
        BiasKernel->>BiasKernel: grouped_bias_add_kernel UseScale=true/false
        Note over BiasKernel: D[row,col] += bias[col] * scale[row]
        BiasKernel-->>Ext: D updated in-place
    end
    Ext-->>Py: D (updated)
```
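The dataflow above can be checked against a minimal pure-Python reference for the single-tensor case; the function name `gemm_then_scaled_bias` is hypothetical and not part of the PR:

```python
def gemm_then_scaled_bias(A, B, C, bias, scale, alpha=1.0, beta=1.0):
    """Reference semantics for the sequence above:
    D = alpha * A @ B + beta * C, then D[row][col] += bias[col] * scale[row]."""
    m, k, n = len(A), len(B), len(B[0])
    D = [[alpha * sum(A[i][t] * B[t][j] for t in range(k)) + beta * C[i][j]
          for j in range(n)] for i in range(m)]
    for i in range(m):      # per-token (per-row) scale ...
        for j in range(n):  # ... applied to a per-column bias
            D[i][j] += bias[j] * scale[i]
    return D
```

With an identity A this reduces to scaling the bias row by row, which is a convenient sanity check for the kernel's fmaf ordering.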
Reviews (3). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```cpp
const size_t tensor_idx = blockIdx.y;
if (tensor_idx >= num_tensors) return;

const int64_t n = d_meta.last_dims ? d_meta.last_dims[0] : d_meta.uniform_last;
```
Hardcoded index [0] instead of [tensor_idx]

d_meta.last_dims[0] works only because the pre-launch NVTE_CHECK(outputD->all_same_last_dim() ...) enforces a uniform last dimension. The hardcoded index hides the per-tensor correctness argument: a future reader (or a refactor that relaxes the uniform check) would not immediately see why [0] is used instead of [tensor_idx]. A comment linking this to the uniformity invariant would make it self-documenting.
```diff
-const int64_t n = d_meta.last_dims ? d_meta.last_dims[0] : d_meta.uniform_last;
+const int64_t n = d_meta.last_dims ? d_meta.last_dims[0]  // uniform across tensors (checked)
+                                   : d_meta.uniform_last;
```
```cpp
int64_t scale_row_offset = 0;
if constexpr (UseScale) {
  if (d_meta.first_dims) {
    for (size_t i = 0; i < tensor_idx; i++) {
      scale_row_offset += d_meta.first_dims[i];
    }
  } else {
    scale_row_offset = static_cast<int64_t>(tensor_idx) * d_meta.uniform_first;
  }
}
```
Redundant per-thread scale_row_offset loop

Every thread in the block (all 256 of them) independently computes scale_row_offset by iterating up to tensor_idx times over d_meta.first_dims. Since tensor_idx == blockIdx.y, all threads in a block produce the same value. For large num_tensors, computing the offset once (e.g. by thread 0) and sharing it through shared memory would avoid the redundant iterations. The broadcast access pattern through L1 is benign for small num_tensors, but is worth noting for scalability.
```diff
 std::optional<SwizzledGroupedScales> maybe_swizzle_grouped_tensor(GroupedTensorWrapper &input,
                                                                   bool rowwise_usage,
                                                                   bool columnwise_usage) {
-  if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) {
+  if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING &&
+      input.scaling_mode() != NVTE_NVFP4_1D_SCALING) {
     return std::nullopt;
   }
```
Unrelated FP4 swizzle change — should be documented

This guard extension (adding NVTE_NVFP4_1D_SCALING) is a separate fix that enables grouped-tensor scale swizzling for FP4 inputs; it is unrelated to the Scaled Bias Add feature described in the PR title. nvte_swizzle_grouped_scaling_factors does handle FP4 in swizzle.cu, so the change is mechanically correct, but the motivation should be documented in the PR description, or in a comment here explaining why FP4 also needs this path.
timmoon10 left a comment

My big question is whether the kernel implementation changes are providing a perf benefit.
```diff
 py::handle A, bool transa, py::handle B, bool transb, py::handle D, py::object bias,
-at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas,
+at::Tensor bias_scale, at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup,
 bool use_split_accumulator, int math_sm_count) {
```
We should avoid the overhead of constructing a tensor when bias_scale isn't needed. std::optional also communicates the intent more clearly.

```diff
-at::Tensor bias_scale, at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup,
+std::optional<at::Tensor> bias_scale, at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup,
```
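The optional-argument shape being suggested can be illustrated host-side. A sketch in Python with plain lists; the function `bias_add` and its signature are illustrative, not the extension's actual API:

```python
def bias_add(out_rows, bias, scale=None):
    """In-place bias add; when scale is None no sentinel object is built
    and the unscaled path is taken, mirroring std::optional semantics."""
    for r, row in enumerate(out_rows):
        s = 1.0 if scale is None else scale[r]
        for c in range(len(row)):
            row[c] += bias[c] * s
    return out_rows
```

The None default replaces the empty-tensor sentinel: callers that don't scale pay nothing, and the scaled path is explicit at the call site.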
```python
if bias_scale is None:
    bias_scale = torch.empty(0, dtype=torch.float32, device=device)
```

We can avoid this overhead by making the tex function take an optional argument.

```diff
-if bias_scale is None:
-    bias_scale = torch.empty(0, dtype=torch.float32, device=device)
```
```diff
 void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias,
-                           cudaStream_t stream);
+                           const NVTETensor scale, cudaStream_t stream);
```

I think it makes more sense to create a separate API, nvte_grouped_scaled_bias_add. Grouped bias is a natural generalization of linear-layer biases, but grouped scaled bias is less intuitive (especially since the biases are per-group while the scales are per-token) and should be treated as more exotic.
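For reference, the semantics under discussion (per-group biases, per-token scales) can be written out in pure Python; `grouped_scaled_bias_add` here is a hypothetical name for the separate API being proposed:

```python
def grouped_scaled_bias_add(groups, biases, scales):
    """D[row][col] += biases[g][col] * scales[row], where g is the group
    owning the row and rows are numbered globally across all groups."""
    row = 0
    for g, mat in enumerate(groups):
        for r in range(len(mat)):
            for c in range(len(mat[r])):
                mat[r][c] += biases[g][c] * scales[row]
            row += 1
    return groups
```

The asymmetry the reviewer points out is visible in the indexing: biases is indexed by group, scales by the global row counter.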
```python
import torch
import torch.nn as nn
from torch.nn import Parameter
```

Nit: is there a reason we're reordering these imports? If the import order causes problems, then that's a bug we need to fix. Otherwise, this ordering seems strangely unmotivated and haphazard. It's also considered good Python style to put third-party imports before local imports (PEP 8).
```cpp
constexpr int kMaxTensors = 257;
__shared__ int cumsum[kMaxTensors];
```

The variable name is wrong: 257 is the size of the cumsum array (one more entry than the maximum tensor count), not the maximum number of tensors.

```diff
-constexpr int kMaxTensors = 257;
-__shared__ int cumsum[kMaxTensors];
+constexpr int kMaxTensors = 256;
+__shared__ int cumsum[kMaxTensors + 1];
```
```cpp
// Binary search for the starting row's tensor.
int tensor_idx;
{
  int lo = 0, hi = num_tensors;
  while (lo < hi) {
    int mid = (lo + hi) >> 1;
    if (cumsum[mid + 1] <= row_start)
      lo = mid + 1;
    else
      hi = mid;
  }
  tensor_idx = lo;
}
int bias_idx = tensor_idx * n;
```
Have we benchmarked whether this binary search is any better than just scanning through the tensors? Computing the cumsum is still O(n), so we're not improving the asymptotics, and we're introducing thread syncs and shared-memory accesses.
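The trade-off can be framed host-side: both lookups agree, and the binary search is O(log n) per thread versus O(n) for the scan, but the O(n) cumsum build dominates either way. A Python sketch of the two row-to-tensor lookups (helper names are illustrative):

```python
import bisect

def tensor_idx_binary(cumsum, row_start):
    # cumsum[i] = global first row of tensor i; cumsum[-1] = total rows
    return bisect.bisect_right(cumsum, row_start) - 1

def tensor_idx_scan(cumsum, row_start):
    # Linear alternative: walk forward until row_start falls inside tensor i
    i = 0
    while cumsum[i + 1] <= row_start:
        i += 1
    return i
```

Whether the log-factor saving pays for the extra shared-memory traffic and syncs is exactly the benchmarking question raised above.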
```diff
-const auto *b_vec = reinterpret_cast<const VecType *>(bias_ptr + col);
 VecStorage b_in;
-b_in.scratch_.aligned = *b_vec;
+b_in.scratch_.aligned = *reinterpret_cast<const VecType *>(bias + bias_idx + col);
```
This value is immediately wiped out in the loop. The compiler might be smart enough to elide the unnecessary memory access, but it makes the code harder to read.